function OUT = TUSIM(SEED,NREP,DGP)
% Simulation code.  Written for brute force parallelization.
% Objects to be estimated are set within the function.
% Last edit: 3/28/17 CBH
% OUT = TUSIM(SEED,NREP,DGP)
% SEED - seed for random number generator
% NREP - number of replications (NREP = 0 returns empty result arrays)
% DGP  - struct describing the data generating process.  Fields read here:
%   n           - simulation sample size
%   k           - number of x variables (excluding the intercept and
%                 treatment dummy)
%   S           - matrix that post-multiplies the n-by-k standard normal
%                 draws to form x
%   trunc       - truncation points; X = (x - trunc').*(x > trunc')
%                 (presumably a k-by-1 vector -- confirm against caller)
%   prop        - treatment dummy is d = (uniform draw > prop)
%   a0,a1,b0,b1 - coefficients of the conditional mean (see model below)
%   scaleNorm   - noise scale: std(y|z) = sqrt(mean^2/scaleNorm)
%   TRUE        - index selecting the columns of z used in the "true model"
%                 regression
%   coef        - true coefficient vector, indexed like the columns of z
%   x0          - weight vector defining the CATE as the linear combination
%                 x0'*b
%   te0, pi0    - true values used to compute coverage for the CATE and
%                 for expected profit
% NOTE: an earlier version of this header listed n, k and s (number of
% relevant x variables) as arguments; n and k are now read from DGP, and s
% is not read by this function.
%
% Simulation model:
% y = a0 + x'*a1 + b0*d + (x.*d)'*b1 + e
% with regressor matrix z = [1, X, d, X.*d].
%
% Returns a struct OUT with the input DGP and, for three targets
% (RegCoef: coefficient on the first interaction term; TE: CATE(x0);
% PI: expected profit per individual), the fields
%   true     - true value of the target
%   b / te   - NREP-row matrix of point estimates
%   se       - matching standard errors
%   fslower, fsupper - targeted undersmoothing interval endpoints
%              (NREP-by-floor(sqrt(n)))
%   cover    - indicator that the interval strictly contains the truth
%   length   - interval length
% Point estimate columns for RegCoef are
%   [all regressors, true model, double selection, lasso (plug-in),
%    post-lasso, lasso (CV)];
% TE and PI use the same order without the double selection column.
% The "all regressors" column is NaN when z has at least n columns.

% set random number seed
rng(SEED);

% Main simulation loop
% Preallocate objects of interest
% Point estimates
b_hat = zeros(NREP,6);
te_hat = zeros(NREP,5);
pi_hat = zeros(NREP,5);

% Standard errors for conventional estimators
b_se = zeros(NREP,6);
te_se = zeros(NREP,5);
pi_se = zeros(NREP,5);

% Interval estimates
nstep = floor(sqrt(DGP.n));
b_lower = zeros(NREP,nstep);
b_upper = zeros(NREP,nstep);
pi_lower = zeros(NREP,nstep);
pi_upper = zeros(NREP,nstep);
te_lower = zeros(NREP,nstep);
te_upper = zeros(NREP,nstep);

Coverage_b = zeros(NREP,nstep);
Coverage_pi = zeros(NREP,nstep);
Coverage_te = zeros(NREP,nstep);

Length_b = zeros(NREP,nstep);
Length_pi = zeros(NREP,nstep);
Length_te = zeros(NREP,nstep);

% Index for first interaction term.  Columns of z are
% [intercept, X (k columns), d, X.*d (k columns)], so the first interaction
% sits in column k+3.  Defined before the loop because it does not change
% across replications and is needed when assembling OUT even if NREP = 0.
b11index = DGP.k+3;

for ii = 1:NREP
    fprintf('Replication # %d \n',ii);

    % Generate data
    x = randn(DGP.n,DGP.k)*DGP.S;
    % Shift by the truncation point and set values at or below it to zero
    X = (x - ones(DGP.n,1)*DGP.trunc').*(x > ones(DGP.n,1)*DGP.trunc');
    d = rand(DGP.n,1) > DGP.prop;
    % Interactions of every X column with the treatment dummy
    w = X.*(d*ones(1,DGP.k));
    z = [ones(DGP.n,1),X,d,w];
    meany = DGP.a0 + X*DGP.a1 + d*DGP.b0 + w*DGP.b1;
    % Heteroskedastic noise: standard deviation proportional to |meany|
    stdy = sqrt((meany.^2)/DGP.scaleNorm);
    y = meany + stdy.*randn(DGP.n,1);
    pz = size(z,2);
    
    % Treatment and control subsets
    yd0 = y(d==0);
    yd1 = y(d==1);
    Xd0 = X(d==0,:);
    Xd1 = X(d==1,:);
    zd0 = [ones(sum(d==0),1) Xd0];
    zd1 = [ones(sum(d),1) Xd1];
    
    % Estimate models
    % All regressors (only when there are fewer columns than observations)
    if pz < DGP.n
        b.all = z\y;
        [se.all,V.all] = hetero_se(z,y-z*b.all,inv(z'*z));
    end
    
    % Initialize for estimators that do not return estimate for all
    % parameters
    b.true = zeros(pz,1);
    se.true = zeros(pz,1);
    V.true = zeros(pz,pz);
    
    b.pl = zeros(pz,1);
    se.pl = zeros(pz,1);
    V.pl = zeros(pz,pz);
    
    % True model: least squares on the columns selected by DGP.TRUE,
    % with zeros left in the remaining positions
    zT = z(:,DGP.TRUE);
    coeftrue = zT\y;
    [setrue,Vtrue] = hetero_se(zT,y-zT*coeftrue,inv(zT'*zT));
    b.true(DGP.TRUE) = coeftrue;
    se.true(DGP.TRUE) = setrue;
    V.true(DGP.TRUE,DGP.TRUE) = Vtrue;
    
    % Lasso with plug-in tuning
    % Control observations
    % Initialize with 5 most marginally correlated variables
    % (czys(5) requires DGP.k >= 5)
    czy0 = corr(yd0,Xd0);
    czys = sort(-abs(czy0));
    suppy0 = find(abs(czy0) >= -czys(5));
    % Least squares of demeaned y on the demeaned initial support
    tmpy0 = (Xd0(:,suppy0)-ones(sum(d==0),1)*mean(Xd0(:,suppy0)))\(yd0-mean(yd0));
    by0 = zeros(size(Xd0,2),1);
    by0(suppy0) = tmpy0;
    blas0 = feasiblePostLasso(yd0-mean(yd0),Xd0-ones(sum(d==0),1)*mean(Xd0),...
        'MaxIter',1,'beta0',by0);
    % Recover the intercept from the demeaned fit and prepend it
    blas0 = [(mean(yd0)-mean(Xd0)*blas0);blas0]; %#ok<*AGROW>
    use0 = blas0 ~= 0;
    % Covariance is filled in only for the selected (nonzero) coefficients
    Vlas0 = zeros(numel(blas0));
    [~,Vtmp0] = hetero_se(zd0(:,use0),yd0-zd0*blas0,inv(zd0(:,use0)'*zd0(:,use0)));
    Vlas0(use0,use0) = Vtmp0;
    
    % Treatment observations
    czy1 = corr(yd1,Xd1);
    czys = sort(-abs(czy1));
    suppy1 = find(abs(czy1) >= -czys(5));
    tmpy1 = (Xd1(:,suppy1)-ones(sum(d),1)*mean(Xd1(:,suppy1)))\(yd1-mean(yd1));
    by1 = zeros(size(Xd1,2),1);
    by1(suppy1) = tmpy1;
    blas1 = feasiblePostLasso(yd1-mean(yd1),Xd1-ones(sum(d),1)*mean(Xd1),...
        'MaxIter',1,'beta0',by1);
    blas1 = [(mean(yd1)-mean(Xd1)*blas1);blas1];
    use1 = blas1 ~= 0;
    Vlas1 = zeros(numel(blas1));
    [~,Vtmp1] = hetero_se(zd1(:,use1),yd1-zd1*blas1,inv(zd1(:,use1)'*zd1(:,use1)));
    Vlas1(use1,use1) = Vtmp1;
    
    % Overall lasso results
    % Stack as [control coefficients; treatment minus control] so the
    % vector lines up with the columns of z = [1, X, d, X.*d]
    b.lasp = [blas0;blas1-blas0];
    V.lasp = blkdiag(Vlas0,Vlas1+Vlas0);
    se.lasp = sqrt(diag(V.lasp));
    
    % Post-lasso
    % Use the union of variables selected in either subsample, both as
    % main effects and (shifted by k+1) as the matching interaction columns
    usepl = union(find(use0),find(use1));
    usepl = [usepl;(usepl+DGP.k+1)];
    zpl = z(:,usepl);
    bpl = zpl\y;
    [sepl,Vpl] = hetero_se(zpl,y-zpl*bpl,inv(zpl'*zpl));
    b.pl(usepl) = bpl;
    se.pl(usepl) = sepl;
    V.pl(usepl,usepl) = Vpl;
    
    % Lasso with CV tuning
    % Control observations
    % 10-fold CV; coefficients taken at the IndexMinMSE entry of the path
    [lcv0,slcv0] = lasso(Xd0,yd0,'CV',10);
    blascv0 = [slcv0.Intercept(slcv0.IndexMinMSE) ; lcv0(:,slcv0.IndexMinMSE)];
    usecv0 = blascv0 ~= 0;
    Vlascv0 = zeros(numel(blascv0));
    [~,Vtmpcv0] = hetero_se(zd0(:,usecv0),yd0-zd0*blascv0,inv(zd0(:,usecv0)'*zd0(:,usecv0)));
    Vlascv0(usecv0,usecv0) = Vtmpcv0;
    
    % Treatment observations
    [lcv1,slcv1] = lasso(Xd1,yd1,'CV',10);
    blascv1 = [slcv1.Intercept(slcv1.IndexMinMSE) ; lcv1(:,slcv1.IndexMinMSE)];
    usecv1 = blascv1 ~= 0;
    Vlascv1 = zeros(numel(blascv1));
    [~,Vtmpcv1] = hetero_se(zd1(:,usecv1),yd1-zd1*blascv1,inv(zd1(:,usecv1)'*zd1(:,usecv1)));
    Vlascv1(usecv1,usecv1) = Vtmpcv1;
    
    % Overall lasso results
    b.lascv = [blascv0;blascv1-blascv0];
    V.lascv = blkdiag(Vlascv0,Vlascv1+Vlascv0);
    se.lascv = sqrt(diag(V.lascv));
    
    
    % Inference for objects of interest
    % Coefficient on first interaction term (b1(1))
    
    % Double selection: target regressor is the first interaction column,
    % all remaining columns of z are passed as the other regressors
    [b.ds,se.ds] = PostDouble(y,z(:,b11index),z(:,setdiff(1:size(z,2),b11index)));
    
    % Point estimates
    if pz < DGP.n
        b_hat(ii,:) = [b.all(b11index), b.true(b11index), b.ds, b.lasp(b11index), b.pl(b11index), b.lascv(b11index)];
        b_se(ii,:) = [se.all(b11index), se.true(b11index), se.ds, se.lasp(b11index), se.pl(b11index), se.lascv(b11index)];
    else
        % Full regression is not computed when pz >= n
        b_hat(ii,:) = [NaN, b.true(b11index), b.ds, b.lasp(b11index), b.pl(b11index), b.lascv(b11index)];
        b_se(ii,:) = [NaN, se.true(b11index), se.ds, se.lasp(b11index), se.pl(b11index), se.lascv(b11index)];
    end
    
    % Targeted undersmoothing
    % Selector vector picking out the first interaction coefficient
    a_coef = zeros(size(b.true));
    a_coef(b11index) = 1;
    [b_lower(ii,:),b_upper(ii,:)] = fsel_linear(z,y,a_coef,usepl,nstep);
    Coverage_b(ii,:) = b_lower(ii,:) < DGP.coef(b11index) & b_upper(ii,:) > DGP.coef(b11index);
    Length_b(ii,:) = b_upper(ii,:) - b_lower(ii,:);
    
    
    % CATE(x0)
    a_te = DGP.x0;
    
    % Point estimates
    if pz < DGP.n
        te_hat(ii,:) = a_te'*[b.all, b.true, b.lasp, b.pl, b.lascv];
        te_se(ii,:) = sqrt([a_te'*V.all*a_te , a_te'*V.true*a_te , a_te'*V.lasp*a_te , ...
            a_te'*V.pl*a_te , a_te'*V.lascv*a_te]);
    else
        te_hat(ii,:) = [NaN , a_te'*[b.true, b.lasp, b.pl, b.lascv]];
        te_se(ii,:) = sqrt([NaN , a_te'*V.true*a_te , a_te'*V.lasp*a_te , ...
            a_te'*V.pl*a_te , a_te'*V.lascv*a_te]);
    end
    
    % Targeted undersmoothing
    [te_lower(ii,:),te_upper(ii,:)] = fsel_linear(z,y,a_te,usepl,nstep);
    Coverage_te(ii,:) = te_lower(ii,:) < DGP.te0 & te_upper(ii,:) > DGP.te0;
    Length_te(ii,:) = te_upper(ii,:) - te_lower(ii,:);
    
    
    % Expected profit per individual
    % Xc zeroes the first k+1 columns (intercept and X) and keeps
    % [1, X] in the treatment-dummy and interaction positions
    Xc = [zeros(DGP.n,DGP.k+1) , ones(DGP.n,1) , X];
    
    if pz < DGP.n
        pi_hat(ii,:) = [TUSIM_prof(b.all,Xc) , TUSIM_prof(b.true,Xc) , ...
            TUSIM_prof(b.lasp,Xc) , TUSIM_prof(b.pl,Xc) , TUSIM_prof(b.lascv,Xc)];
        pi_se(ii,:) = [TUSIM_profse(b.all,Xc,1e-6,z,(1:size(z,2))',y-z*b.all) , ...
            TUSIM_profse(b.true,Xc,1e-6,z,find(DGP.TRUE),y-z*b.true) , ...
            TUSIM_profse(b.lasp,Xc,1e-6,z,[find(use0);find(use1)+DGP.k+1],y-z*b.lasp) , ...
            TUSIM_profse(b.pl,Xc,1e-6,z,usepl,y-z*b.pl) , ...
            TUSIM_profse(b.lascv,Xc,1e-6,z,[find(usecv0);find(usecv1)+DGP.k+1],y-z*b.lascv)];
    else
        pi_hat(ii,:) = [NaN , TUSIM_prof(b.true,Xc) , ...
            TUSIM_prof(b.lasp,Xc) , TUSIM_prof(b.pl,Xc) , TUSIM_prof(b.lascv,Xc)];
        pi_se(ii,:) = [NaN , ...
            TUSIM_profse(b.true,Xc,1e-6,z,find(DGP.TRUE),y-z*b.true) , ...
            TUSIM_profse(b.lasp,Xc,1e-6,z,[find(use0);find(use1)+DGP.k+1],y-z*b.lasp) , ...
            TUSIM_profse(b.pl,Xc,1e-6,z,usepl,y-z*b.pl) , ...
            TUSIM_profse(b.lascv,Xc,1e-6,z,[find(usecv0);find(usecv1)+DGP.k+1],y-z*b.lascv)];
    end
    
    % Targeted undersmoothing
    [pi_lower(ii,:),pi_upper(ii,:)] = fsel_pi(z,y,Xc,usepl,nstep);
    Coverage_pi(ii,:) = pi_lower(ii,:) < DGP.pi0 & pi_upper(ii,:) > DGP.pi0;
    Length_pi(ii,:) = pi_upper(ii,:) - pi_lower(ii,:);
    
end

% Collect results
OUT.DGP = DGP;

% Coefficient on the first interaction term
OUT.RegCoef.true = DGP.coef(b11index);
OUT.RegCoef.b = b_hat;
OUT.RegCoef.se = b_se;
OUT.RegCoef.fslower = b_lower;
OUT.RegCoef.fsupper = b_upper;
OUT.RegCoef.cover = Coverage_b;
OUT.RegCoef.length = Length_b;

% CATE(x0)
OUT.TE.true = DGP.te0;
OUT.TE.te = te_hat;
OUT.TE.se = te_se;
OUT.TE.fslower = te_lower;
OUT.TE.fsupper = te_upper;
OUT.TE.cover = Coverage_te;
OUT.TE.length = Length_te;

% Expected profit per individual (point estimates are stored under the
% field name "te", matching the TE struct layout)
OUT.PI.true = DGP.pi0;
OUT.PI.te = pi_hat;
OUT.PI.se = pi_se;
OUT.PI.fslower = pi_lower;
OUT.PI.fsupper = pi_upper;
OUT.PI.cover = Coverage_pi;
OUT.PI.length = Length_pi;


